{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Perturbed optimizers\n",
        "\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google-deepmind/optax/blob/main/examples/perturbations.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l98gJVtFFJPB"
      },
      "source": [
        "In this notebook we review a general method for turning any function $f$ that maps a pytree to another pytree into a differentiable approximation $f_\\varepsilon$, using perturbations, following the method of [Berthet et al. (2020)](https://arxiv.org/abs/2002.08676).\n",
        "\n",
        "For a random variable $Z$ drawn from a distribution with continuous positive density $\\mu$ and a function $f: X \\to Y$, the perturbed approximation is defined for any $x \\in X$ by\n",
        "\n",
        "$$f_\\varepsilon(x) = \\mathbf{E}[f (x + \\varepsilon Z )]\\, .$$\n",
        "\n",
        "We illustrate this on several examples, including the case of an optimizer function $y^*$ over $C$ defined for any cost $\\theta \\in \\mathbb{R}^d$ by\n",
        "\n",
        "$$y^*(\\theta) = \\mathop{\\mathrm{arg\\,max}}_{y \\in C} \\langle y, \\theta \\rangle\\, .$$\n",
        "\n",
        "In this case, the perturbed optimizer is given by\n",
        "\n",
        "$$y_\\varepsilon^*(\\theta) = \\mathbf{E}[\\mathop{\\mathrm{arg\\,max}}_{y\\in C} \\langle y, \\theta + \\varepsilon Z \\rangle]\\, .$$"
      ]
    },
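    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the definition concrete, the expectation can be approximated by a sample mean over draws of $Z$. The following is a minimal sketch, not the library's implementation: it assumes array inputs (not general pytrees) and standard Gaussian noise, and `monte_carlo_perturbed` and `hard_argmax` are hypothetical names introduced for illustration. The parameter `sigma` plays the role of $\\varepsilon$. The sketch only estimates the forward value: differentiating this naive average with JAX would give zero gradients because the argmax is piecewise constant, so a dedicated gradient estimator is needed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def monte_carlo_perturbed(fun, num_samples, sigma):\n",
        "  # Hypothetical helper: averages fun over Gaussian perturbations of the input.\n",
        "\n",
        "  def perturbed(x, key):\n",
        "    # Draw num_samples standard normal perturbations Z with the shape of x.\n",
        "    noise = jax.random.normal(key, (num_samples,) + x.shape)\n",
        "    # Evaluate fun at every perturbed input x + sigma * Z.\n",
        "    outputs = jax.vmap(lambda z: fun(x + sigma * z))(noise)\n",
        "    # The sample mean approximates the expectation E[fun(x + sigma * Z)].\n",
        "    return jnp.mean(outputs, axis=0)\n",
        "\n",
        "  return perturbed\n",
        "\n",
        "\n",
        "def hard_argmax(x):\n",
        "  # One-hot encoding of the position of the largest coefficient.\n",
        "  return jax.nn.one_hot(jnp.argmax(x), x.shape[0])\n",
        "\n",
        "\n",
        "smooth_argmax = monte_carlo_perturbed(hard_argmax, num_samples=1000, sigma=0.5)\n",
        "theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])\n",
        "approx = smooth_argmax(theta, jax.random.PRNGKey(0))\n",
        "print(approx)"
      ]
    },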
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "S6tLyyy9VCEw"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import operator\n",
        "from jax import tree_util as jtu\n",
        "\n",
        "from optax import tree_utils as otu\n",
        "from optax import perturbations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EmYn_jNUFfw2"
      },
      "source": [
        "# Argmax one-hot"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BmHzgCn7FJPC"
      },
      "source": [
        "We consider an optimizer, such as the following `argmax_one_hot` function. It transforms a real-valued vector into a binary vector with a 1 at the position of the largest coefficient and 0 elsewhere. It corresponds to $y^*$ for $C$ being the unit simplex. We run it on an example input `values`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84N-wAJ8GDK2"
      },
      "source": [
        "## One-hot function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "kMZnzhX4FjGj"
      },
      "outputs": [],
      "source": [
        "def argmax_one_hot(x, axis=-1):\n",
        "  return jax.nn.one_hot(jnp.argmax(x, axis=axis), x.shape[axis])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "iynCk8734Wiz"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[0. 1. 0. 0. 0.]\n"
          ]
        }
      ],
      "source": [
        "values = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])\n",
        "\n",
        "one_hot_vec = argmax_one_hot(values)\n",
        "print(one_hot_vec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6rbNt-6zGb-J"
      },
      "source": [
        "## One-hot with perturbations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5lvIhCV1FJPD"
      },
      "source": [
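        "The function `perturbations.make_perturbed_fun` transforms `argmax_one_hot` into a perturbed version that we call `pert_one_hot`. Here `num_samples` is the number of noise draws that are averaged, `sigma` plays the role of $\\varepsilon$, and we use Gumbel noise for the perturbation."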
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "7hQz6zuPwkpZ"
      },
      "outputs": [],
      "source": [
        "N_SAMPLES = 100\n",
        "SIGMA = 0.5\n",
        "GUMBEL = perturbations.Gumbel()\n",
        "\n",
        "rng = jax.random.PRNGKey(1)\n",
        "pert_one_hot = perturbations.make_perturbed_fun(fun=argmax_one_hot,\n",
        "                                                num_samples=N_SAMPLES,\n",
        "                                                sigma=SIGMA,\n",
        "                                                noise=GUMBEL)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I2VBjnSUFJPD"
      },
      "source": [
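        "In this particular case, the perturbed optimizer is equal to the usual [softmax function](https://en.wikipedia.org/wiki/Softmax_function) applied to `values / SIGMA`, and the sampled estimate differs from it only by Monte Carlo error. This is not always true: in general there is no closed form for $y_\\varepsilon^*$."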
      ]
    },
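    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The identity with the softmax can be checked with the Gumbel-max trick. As a short sketch, write $\\theta$ for `values` and $\\varepsilon$ for `SIGMA`, and assume that the coordinates $Z_i$ are independent standard Gumbel variables. Since $y^*$ returns a one-hot vector, coordinate $i$ of $y_\\varepsilon^*(\\theta)$ is the probability that index $i$ attains the maximum:\n",
        "\n",
        "$$[y_\\varepsilon^*(\\theta)]_i = \\mathbf{P}\\big(\\theta_i + \\varepsilon Z_i \\ge \\theta_j + \\varepsilon Z_j \\text{ for all } j\\big) = \\frac{\\exp(\\theta_i / \\varepsilon)}{\\sum_j \\exp(\\theta_j / \\varepsilon)}\\, .$$\n",
        "\n",
        "An average over `N_SAMPLES` draws is a Monte Carlo estimate of this probability vector, so it matches the softmax only up to sampling error."
      ]
    },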
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "f2gDpghJYZ33"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "computation with 100 samples, sigma = 0.5\n",
            "perturbed argmax = [0.01 0.8  0.02 0.16 0.01]\n",
            "softmax = [0.00549293 0.8152234  0.01222475 0.16459078 0.00246813]\n",
            "norm of softmax = 8.32e-01\n",
            "norm of difference = 1.98e-02\n"
          ]
        }
      ],
      "source": [
        "rngs = jax.random.split(rng, 2)\n",
        "\n",
        "rng = rngs[0]\n",
        "\n",
        "pert_argmax = pert_one_hot(values, rng)\n",
        "print(f'computation with {N_SAMPLES} samples, sigma = {SIGMA}')\n",
        "print(f'perturbed argmax = {pert_argmax}')\n",
        "# With Gumbel noise, the exact expectation is the softmax of values / SIGMA.\n",
        "soft_max = jax.nn.softmax(values / SIGMA)\n",
        "print(f'softmax = {soft_max}')\n",
        "print(f'norm of softmax = {jnp.linalg.norm(soft_max):.2e}')\n",
        "print(f'norm of difference = {jnp.linalg.norm(pert_argmax - soft_max):.2e}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2U7rhtEAGpMV"
      },
      "source": [
        "## Gradients for one-hot with perturbations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ldxsWvLmFJPD"
      },
      "source": [
        "The perturbed optimizer $y_\\varepsilon^*$ is differentiable, and its gradient can be estimated stochastically and automatically using `jax.grad`.\n",
        "\n",
        "We create a scalar loss `loss_simplex` of the perturbed optimizer $y^*_\\varepsilon$,\n",
        "\n",
        "$$\\ell_\\text{simplex}(y_\\varepsilon^*; y_{\\text{true}})\\, .$$\n",
        "\n",
        "For `values` equal to a vector $\\theta$, we can compute gradients of\n",
        "\n",
        "$$\\ell(\\theta) = \\ell_\\text{simplex}(y_\\varepsilon^*(\\theta); y_{\\text{true}})$$\n",
        "\n",
        "with respect to `values`, automatically."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "7H1LD4QhGtFI"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Array(0.53180003, dtype=float32)"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Example loss function\n",
        "\n",
        "def loss_simplex(values, rng):\n",
        "  n = values.shape[0]\n",
        "  v_true = jnp.arange(n) + 2\n",
        "  y_true = v_true / jnp.sum(v_true)\n",
        "  y_pred = pert_one_hot(values, rng)\n",
        "  return jnp.sum((y_true - y_pred) ** 2)\n",
        "\n",
        "loss_simplex(values, rngs[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CM2poXb4FJPD"
      },
      "source": [
        "We can compute the gradient of $\\ell$ directly:\n",
        "\n",
        "$$\\nabla_\\theta \\ell(\\theta) = \\partial_\\theta y^*_\\varepsilon(\\theta) \\cdot \\nabla_1 \\ell_{\\text{simplex}}(y^*_\\varepsilon(\\theta); y_{\\text{true}})\\, .$$\n",
        "\n",
        "The Jacobian $\\partial_\\theta y^*_\\varepsilon(\\theta)$ is estimated automatically, using the method given by [Berthet et al. (2020)](https://arxiv.org/abs/2002.08676), Prop. 3.1."
      ]
    },
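    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the Jacobian estimate concrete, here is a minimal sketch for the special case of standard Gaussian noise. This is an assumption of the example: the cells in this notebook use Gumbel noise and the estimator implemented in `optax.perturbations`. For Gaussian $Z$, the identity $\\partial_\\theta y^*_\\varepsilon(\\theta) = \\mathbf{E}[y^*(\\theta + \\varepsilon Z) Z^\\top] / \\varepsilon$ holds, and the expectation can be replaced by a sample mean. The helpers `gaussian_jacobian_estimate` and `hard_argmax` are hypothetical names, not part of the library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def hard_argmax(x):\n",
        "  # One-hot encoding of the position of the largest coefficient.\n",
        "  return jax.nn.one_hot(jnp.argmax(x), x.shape[0])\n",
        "\n",
        "\n",
        "def gaussian_jacobian_estimate(fun, x, key, num_samples, sigma):\n",
        "  # Hypothetical helper (Gaussian noise only), not the optax implementation.\n",
        "  noise = jax.random.normal(key, (num_samples,) + x.shape)\n",
        "  outputs = jax.vmap(lambda z: fun(x + sigma * z))(noise)\n",
        "  # Average the outer products fun(x + sigma * Z) Z^T, then divide by sigma.\n",
        "  # Entry [i, j] estimates the derivative of output i w.r.t. input j.\n",
        "  return jnp.einsum('ki,kj->ij', outputs, noise) / (num_samples * sigma)\n",
        "\n",
        "\n",
        "theta = jnp.array([-0.6, 1.9, -0.2, 1.1, -1.0])\n",
        "jac = gaussian_jacobian_estimate(\n",
        "    hard_argmax, theta, jax.random.PRNGKey(0), num_samples=10_000, sigma=0.5)\n",
        "print(jac.shape)\n",
        "print(jnp.diag(jac))"
      ]
    },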
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "tjQatCE3GtFJ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[-0.18402426  0.3660836  -0.0462517  -0.09384193 -0.02169794]\n"
          ]
        }
      ],
      "source": [
        "# Gradient of the loss w.r.t input values\n",
        "\n",
        "gradient = jax.grad(loss_simplex)(values, rngs[1])\n",
        "print(gradient)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Eh2Qt97AFJPD"
      },
      "source": [
        "We illustrate the use of this method by running 200 steps of gradient descent on $\\theta_t$ so that it minimizes this loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "MuNE2RX0GtFJ"
      },
      "outputs": [],
      "source": [
        "# Run 200 steps of gradient descent on the values to match the target y_true\n",
        "\n",
        "steps = 200\n",
        "values_t = values\n",
        "eta = 0.5\n",
        "\n",
        "grad_func = jax.jit(jax.grad(loss_simplex))\n",
        "\n",
        "for t in range(steps):\n",
        "  rngs = jax.random.split(rngs[1], 2)\n",
        "  values_t = values_t - eta * grad_func(values_t, rngs[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "29TWHiH0GtFJ"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial values = [-0.6  1.9 -0.2  1.1 -1. ]\n",
            "initial one-hot = [0. 1. 0. 0. 0.]\n",
            "initial diff. one-hot = [0.   0.85 0.01 0.14 0.  ]\n",
            "\n",
            "values after GD = [-0.11319756  0.1024728   0.23011085  0.3087453   0.37857407]\n",
            "one-hot after GD = [0. 0. 0. 0. 1.]\n",
            "diff. one-hot after GD = [0.09 0.18 0.21 0.19 0.33]\n",
            "target diff. one-hot = [0.1  0.15 0.2  0.25 0.3 ]\n"
          ]
        }
      ],
      "source": [
        "rngs = jax.random.split(rngs[1], 2)\n",
        "\n",
        "n = values.shape[0]\n",
        "v_true = jnp.arange(n) + 2\n",
        "y_true = v_true / jnp.sum(v_true)\n",
        "\n",
        "print(f'initial values = {values}')\n",
        "print(f'initial one-hot = {argmax_one_hot(values)}')\n",
        "print(f'initial diff. one-hot = {pert_one_hot(values, rngs[0])}')\n",
        "print()\n",
        "print(f'values after GD = {values_t}')\n",
        "print(f'one-hot after GD = {argmax_one_hot(values_t)}')\n",
        "print(f'diff. one-hot after GD = {pert_one_hot(values_t, rngs[1])}')\n",
        "print(f'target diff. one-hot = {y_true}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Vyh_a1bZT-s"
      },
      "source": [
        "# Differentiable ranking"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QmVAjbJxFzUA"
      },
      "source": [
        "## Ranking function"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gyapGu77FJPE"
      },
      "source": [
        "We now consider another optimizer, the following `ranking` function. It transforms a real-valued vector of size $n$ into a vector whose coefficients are a permutation of $\\{0,\\ldots, n-1\\}$, giving the rank of each coefficient of the original vector. It corresponds to $y^*$ for $C$ being the permutahedron. We run it on an example input `values`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "-NKbR6TlZUTG"
      },
      "outputs": [],
      "source": [
        "# Function outputting a vector of ranks\n",
        "\n",
        "def ranking(values):\n",
        "  return jnp.argsort(jnp.argsort(values))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "iU69uMAoZncY"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "values = [ 0.18784384 -1.2833426   0.6494181   1.2490594   0.24447003 -0.11744965]\n",
            "ranking = [2 0 4 5 3 1]\n"
          ]
        }
      ],
      "source": [
        "# Example on random values\n",
        "\n",
        "n = 6\n",
        "\n",
        "rng = jax.random.PRNGKey(0)\n",
        "values = jax.random.normal(rng, (n,))\n",
        "\n",
        "print(f'values = {values}')\n",
        "print(f'ranking = {ranking(values)}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5j1Vgfz_bb9u"
      },
      "source": [
        "## Ranking with perturbations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eu2wfbNuFJPE"
      },
      "source": [
        "As above, our implementation transforms this function into a perturbed one that we call `pert_ranking`. In this case we use Gumbel noise for the perturbation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "Equ3_gDPbf5n"
      },
      "outputs": [],
      "source": [
        "N_SAMPLES = 100\n",
        "SIGMA = 0.2\n",
        "GUMBEL = perturbations.Gumbel()\n",
        "\n",
        "pert_ranking = perturbations.make_perturbed_fun(ranking,\n",
        "                                                num_samples=N_SAMPLES,\n",
        "                                                sigma=SIGMA,\n",
        "                                                noise=GUMBEL)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "vMj-Dnudby_a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "values = [ 0.18784384 -1.2833426   0.6494181   1.2490594   0.24447003 -0.11744965]\n",
            "diff_ranks = [2.41 0.   3.92 4.93 2.36 1.38]\n"
          ]
        }
      ],
      "source": [
        "# Expectation of the perturbed ranks on these values\n",
        "\n",
        "rngs = jax.random.split(rng, 2)\n",
        "\n",
        "diff_ranks = pert_ranking(values, rngs[0])\n",
        "print(f'values = {values}')\n",
        "\n",
        "print(f'diff_ranks = {diff_ranks}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aH6Ew85koQvU"
      },
      "source": [
        "## Gradients for ranking with perturbations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UDZOEt18FJPE"
      },
      "source": [
        "As above, the perturbed optimizer $y_\\varepsilon^*$ is differentiable, and its gradient can be computed with stochastic estimation automatically, using `jax.grad`.\n",
        "\n",
        "We showcase this on a loss of $y^*_\\varepsilon(\\theta)$ that can be directly differentiated w.r.t. the `values` equal to $\\theta$."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "O-T8y6N8cHzF"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "29.7414\n"
          ]
        }
      ],
      "source": [
        "# Example loss function\n",
        "\n",
        "def loss_example(values, rng):\n",
        "  n = values.shape[0]\n",
        "  y_true = ranking(jnp.arange(n))\n",
        "  y_pred = pert_ranking(values, rng)\n",
        "  return jnp.sum((y_true - y_pred) ** 2)\n",
        "\n",
        "print(loss_example(values, rngs[1]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "v7nzNwP-e68q"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[ 5.145397   4.239827   1.2263288  2.1130772  1.2175393 -3.1376717]\n"
          ]
        }
      ],
      "source": [
        "# Gradient of the objective w.r.t input values\n",
        "\n",
        "gradient = jax.grad(loss_example)(values, rngs[1])\n",
        "print(gradient)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aC7IzKADFJPE"
      },
      "source": [
        "As above, we run gradient descent on the values, here for 20 steps, to minimize this loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "0UObBP3QfCqq"
      },
      "outputs": [],
      "source": [
        "steps = 20\n",
        "values_t = values\n",
        "eta = 0.1\n",
        "\n",
        "grad_func = jax.jit(jax.grad(loss_example))\n",
        "\n",
        "for t in range(steps):\n",
        "  rngs = jax.random.split(rngs[1], 2)\n",
        "  values_t = values_t - eta * grad_func(values_t, rngs[1])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "p4iNxMoQmZRa"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial values = [ 0.18784384 -1.2833426   0.6494181   1.2490594   0.24447003 -0.11744965]\n",
            "initial ranks = [2 0 4 5 3 1]\n",
            "initial diff. ranks = [2.3  0.   3.81 4.97 2.55 1.37]\n",
            "\n",
            "values after GD = [-1.2940687  -0.8644269  -0.17311236  0.39050597  0.69519895  1.135534  ]\n",
            "ranks after GD = [0 1 2 3 4 5]\n",
            "diff. ranks after GD = [0.12 0.9  2.07 3.04 4.   4.87]\n",
            "target diff. ranks = [0 1 2 3 4 5]\n"
          ]
        }
      ],
      "source": [
        "rngs = jax.random.split(rngs[1], 2)\n",
        "\n",
        "y_true = ranking(jnp.arange(n))\n",
        "\n",
        "print(f'initial values = {values}')\n",
        "print(f'initial ranks = {ranking(values)}')\n",
        "print(f'initial diff. ranks = {pert_ranking(values, rngs[0])}')\n",
        "print()\n",
        "print(f'values after GD = {values_t}')\n",
        "print(f'ranks after GD = {ranking(values_t)}')\n",
        "print(f'diff. ranks after GD = {pert_ranking(values_t, rngs[1])}')\n",
        "print(f'target diff. ranks = {y_true}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P537S89ZlDQR"
      },
      "source": [
        "# General input / outputs (Pytrees)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V62NPuUvSHk8"
      },
      "source": [
        "This method can be applied to any function taking pytrees as input and output. The perturbed function can be evaluated in the forward pass and can also be differentiated, as illustrated below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "0Bz35ZWQpeB7"
      },
      "outputs": [],
      "source": [
        "tree_a = (jnp.array((0.1, 0.4, 0.5)),\n",
        "          {'k1': jnp.array((0.1, 0.2)),\n",
        "           'k2': jnp.array((0.1, 0.1))},\n",
        "          jnp.array((0.4, 0.3, 0.2, 0.1)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UxcczrOZhCZJ"
      },
      "source": [
        "## Tree argmax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hUQqWJYfgrag"
      },
      "source": [
        "This piecewise constant function applies the argmax to every leaf array of the pytree."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "szID1S5Jg_LL"
      },
      "outputs": [],
      "source": [
        "argmax_tree = lambda x: jtu.tree_map(argmax_one_hot, x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "KPfzjdSJxP4G"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(Array([0., 0., 1.], dtype=float32),\n",
              " {'k1': Array([0., 1.], dtype=float32), 'k2': Array([1., 0.], dtype=float32)},\n",
              " Array([1., 0., 0., 0.], dtype=float32))"
            ]
          },
          "execution_count": 20,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "argmax_tree(tree_a)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hOmGdkW0g2-6"
      },
      "source": [
        "The perturbed approximation applies a perturbed argmax to every leaf, which yields a smoothed, softmax-like vector for each of them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "oKuD_cElxSDd"
      },
      "outputs": [],
      "source": [
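        "N_SAMPLES = 100\n",
        "# The scale passed below is SIGMA; set it explicitly (0.2, as in the ranking\n",
        "# example) instead of relying on the value left over from an earlier cell.\n",
        "SIGMA = 0.2\n",
        "\n",
        "pert_argmax_fun = perturbations.make_perturbed_fun(argmax_tree,\n",
        "                                                   num_samples=N_SAMPLES,\n",
        "                                                   sigma=SIGMA)"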
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "_fnpfpVBxYSQ"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "(Array([0.08, 0.34, 0.58], dtype=float32),\n",
              " {'k1': Array([0.31, 0.69], dtype=float32),\n",
              "  'k2': Array([0.51, 0.49], dtype=float32)},\n",
              " Array([0.38, 0.3 , 0.2 , 0.12], dtype=float32))"
            ]
          },
          "execution_count": 22,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "pert_argmax_fun(tree_a, rng)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MQOh_iZmhLVS"
      },
      "source": [
        "## Scalar loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "zW0U0DW1xbAV"
      },
      "outputs": [],
      "source": [
        "def pert_loss(inputs, rng):\n",
        "  pert_softmax = pert_argmax_fun(inputs, rng)\n",
        "  argmax = argmax_tree(inputs)\n",
        "  diffs = jtu.tree_map(lambda x, y: jnp.sum((x - y) ** 2 / 4), argmax, pert_softmax)\n",
        "  return jtu.tree_reduce(operator.add, diffs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "bWjXKeMSxodX"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial loss value = 0.375\n"
          ]
        }
      ],
      "source": [
        "init_loss = pert_loss(tree_a, rng)\n",
        "\n",
        "print(f'initial loss value = {init_loss:.3f}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eXobsx6bhRb8"
      },
      "source": [
        "## Gradient computation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kydUMAachVgp"
      },
      "source": [
        "The gradient of the scalar loss can be evaluated with `jax.grad`. It is a pytree with the same structure as the input."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "id": "vryBVzPsxqlI"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Gradient of the scalar loss\n",
            "\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "(Array([-0.01304852,  0.06675826, -0.04797449], dtype=float32),\n",
              " {'k1': Array([ 0.02245465, -0.08815713], dtype=float32),\n",
              "  'k2': Array([-0.08147696,  0.2535689 ], dtype=float32)},\n",
              " Array([-0.07846132,  0.14177498, -0.06175202,  0.08064784], dtype=float32))"
            ]
          },
          "execution_count": 25,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "grad = jax.grad(pert_loss)(tree_a, rng)\n",
        "\n",
        "print('Gradient of the scalar loss')\n",
        "print()\n",
        "grad"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SpbnFapshcaM"
      },
      "source": [
        "A small step in the direction opposite to the gradient reduces the loss value."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "id": "EkwIh76L1Azl"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "initial loss value = 0.375\n",
            "loss after gradient step = 0.333\n"
          ]
        }
      ],
      "source": [
        "eta = 1e-1\n",
        "\n",
        "loss_step = pert_loss(otu.tree_add_scalar_mul(tree_a, -eta, grad), rng)\n",
        "\n",
        "print(f'initial loss value = {init_loss:.3f}')\n",
        "print(f'loss after gradient step = {loss_step:.3f}')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1CDnfXElxMMd_144LVTaeVCLay3-RXJrA",
          "timestamp": 1726135261814
        },
        {
          "file_id": "1i83GFtgxkGQ6t-WTG0Bz9SMGMvzmVfc7",
          "timestamp": 1726129388187
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
